{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pknVo1kM2wI2"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "SoFqANDE222Y"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6x1ypzczQCwy"
      },
      "source": [
        "# Vertex AI Training with TFX and Vertex Pipelines\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_445qeKq8e3-"
      },
      "source": [
        "<div class=\"devsite-table-wrapper\"><table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "<td><a target=\"_blank\" href=\"https://www.tensorflow.org/tfx/tutorials/tfx/gcp/vertex_pipelines_vertex_training\">\n",
        "<img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\"/>View on TensorFlow.org</a></td>\n",
        "<td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tfx/blob/master/docs/tutorials/tfx/gcp/vertex_pipelines_vertex_training.ipynb\">\n",
        "<img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Run in Google Colab</a></td>\n",
        "<td><a target=\"_blank\" href=\"https://github.com/tensorflow/tfx/tree/master/docs/tutorials/tfx/gcp/vertex_pipelines_vertex_training.ipynb\">\n",
        "<img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">View source on GitHub</a></td>\n",
        "<td><a href=\"https://storage.googleapis.com/tensorflow_docs/tfx/docs/tutorials/tfx/gcp/vertex_pipelines_vertex_training.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a></td>\n",
        "<td><a href=\"https://console.cloud.google.com/mlengine/notebooks/deploy-notebook?q=download_url%3Dhttps%253A%252F%252Fraw.githubusercontent.com%252Ftensorflow%252Ftfx%252Fmaster%252Fdocs%252Ftutorials%252Ftfx%252Fgcp%252Fvertex_pipelines_vertex_training.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Run in Google Cloud AI Platform Notebook</a></td>\n",
        "</table></div>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_VuwrlnvQJ5k"
      },
      "source": [
        "This notebook-based tutorial will create and run a TFX pipeline which trains an\n",
        "ML model using the Vertex AI Training service.\n",
        "\n",
        "This notebook is based on the TFX pipeline we built in\n",
        "[Simple TFX Pipeline for Vertex Pipelines Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/gcp/vertex_pipelines_simple).\n",
        "If you have not read that tutorial yet, you should read it before proceeding\n",
        "with this notebook.\n",
        "\n",
        "You can train models on Vertex AI using AutoML, or use custom training. In\n",
        "custom training, you can select many different machine types to power your\n",
        "training jobs, enable distributed training, use hyperparameter tuning, and\n",
        "accelerate with GPUs.\n",
        "\n",
        "In this tutorial, we will use Vertex AI Training with custom jobs to train\n",
        "a model in a TFX pipeline.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S5Pv2qm3wfpL"
      },
      "source": [
        "This notebook is intended to be run on\n",
        "[Google Colab](https://colab.research.google.com/notebooks/intro.ipynb) or on\n",
        "[AI Platform Notebooks](https://cloud.google.com/ai-platform-notebooks). If you\n",
        "are not using one of these, you can simply click the \"Run in Google Colab\"\n",
        "button above.\n",
        "\n",
        "## Set up\n",
        "If you have completed\n",
        "[Simple TFX Pipeline for Vertex Pipelines Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/gcp/vertex_pipelines_simple),\n",
        "you will have a working GCP project and a GCS bucket and that is all we need\n",
        "for this tutorial. Please read the preliminary tutorial first if you missed it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fwZ0aXisoBFW"
      },
      "source": [
        "### Install Python packages"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WC9W_S-bONgl"
      },
      "source": [
        "We will install required Python packages including TFX and KFP to author ML\n",
        "pipelines and submit jobs to Vertex Pipelines."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iyQtljP-qPHY"
      },
      "outputs": [],
      "source": [
        "# Use the latest version of pip.\n",
        "!pip install --upgrade pip\n",
        "!pip install --upgrade tfx==0.30.0 kfp==1.6.1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EwT0nov5QO1M"
      },
      "source": [
        "#### Did you restart the runtime?\n",
        "\n",
        "If you are using Google Colab, the first time that you run\n",
        "the cell above, you must restart the runtime by clicking\n",
        "the \"RESTART RUNTIME\" button above or using the \"Runtime > Restart\n",
        "runtime ...\" menu. This is because of the way that Colab\n",
        "loads packages."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-CRyIL4LVDlQ"
      },
      "source": [
        "If you are not on Colab, you can restart the runtime with the following cell."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KHTSzMygoBF6"
      },
      "outputs": [],
      "source": [
        "# docs_infra: no_execute\n",
        "import sys\n",
        "if 'google.colab' not in sys.modules:\n",
        "  # Automatically restart kernel after installs\n",
        "  import IPython\n",
        "  app = IPython.Application.instance()\n",
        "  app.kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gckGHdW9iPrq"
      },
      "source": [
        "### Log in to Google for this notebook\n",
        "If you are running this notebook on Colab, authenticate with your user account:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kZQA0KrfXCvU"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "if 'google.colab' in sys.modules:\n",
        "  from google.colab import auth\n",
        "  auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aaqJjbmk6o0o"
      },
      "source": [
        "**If you are on AI Platform Notebooks**, authenticate with Google Cloud before\n",
        "running the next section, by running\n",
        "```sh\n",
        "gcloud auth login\n",
        "```\n",
        "**in the Terminal window** (which you can open via **File** > **New** in the\n",
        "menu). You only need to do this once per notebook instance."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3_SveIKxaENu"
      },
      "source": [
        "Check the package versions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xd-iP9wEaENu"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "print('TensorFlow version: {}'.format(tf.__version__))\n",
        "from tfx import v1 as tfx\n",
        "print('TFX version: {}'.format(tfx.__version__))\n",
        "import kfp\n",
        "print('KFP version: {}'.format(kfp.__version__))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aDtLdSkvqPHe"
      },
      "source": [
        "### Set up variables\n",
        "\n",
        "We will set up some variables used to customize the pipelines below. The\n",
        "following information is required:\n",
        "\n",
        "* GCP Project id. See\n",
        "[Identifying your project id](https://cloud.google.com/resource-manager/docs/creating-managing-projects#identifying_projects).\n",
        "* GCP Region to run pipelines. For more information about the regions that\n",
        "Vertex Pipelines is available in, see the\n",
        "[Vertex AI locations guide](https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability).\n",
        "* Google Cloud Storage Bucket to store pipeline outputs.\n",
        "\n",
        "**Enter required values in the cell below before running it**.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EcUseqJaE2XN"
      },
      "outputs": [],
      "source": [
        "GOOGLE_CLOUD_PROJECT = ''     # <--- ENTER THIS\n",
        "GOOGLE_CLOUD_REGION = ''      # <--- ENTER THIS\n",
        "GCS_BUCKET_NAME = ''          # <--- ENTER THIS\n",
        "\n",
        "if not (GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_REGION and GCS_BUCKET_NAME):\n",
        "    from absl import logging\n",
        "    logging.error('Please set all required parameters.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GAaCPLjgiJrO"
      },
      "source": [
        "Set `gcloud` to use your project."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VkWdxe4TXRHk"
      },
      "outputs": [],
      "source": [
        "!gcloud config set project {GOOGLE_CLOUD_PROJECT}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CPN6UL5CazNy"
      },
      "outputs": [],
      "source": [
        "PIPELINE_NAME = 'penguin-vertex-training'\n",
        "\n",
        "# Path to various pipeline artifacts.\n",
        "PIPELINE_ROOT = 'gs://{}/pipeline_root/{}'.format(GCS_BUCKET_NAME, PIPELINE_NAME)\n",
        "\n",
        "# Paths for users' Python module.\n",
        "MODULE_ROOT = 'gs://{}/pipeline_module/{}'.format(GCS_BUCKET_NAME, PIPELINE_NAME)\n",
        "\n",
        "# Paths for users' data.\n",
        "DATA_ROOT = 'gs://{}/data/{}'.format(GCS_BUCKET_NAME, PIPELINE_NAME)\n",
        "\n",
        "# This is the path where your model will be pushed for serving.\n",
        "SERVING_MODEL_DIR = 'gs://{}/serving_model/{}'.format(\n",
        "    GCS_BUCKET_NAME, PIPELINE_NAME)\n",
        "\n",
        "print('PIPELINE_ROOT: {}'.format(PIPELINE_ROOT))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8F2SRwRLSYGa"
      },
      "source": [
        "### Prepare example data\n",
        "We will use the same\n",
        "[Palmer Penguins dataset](https://allisonhorst.github.io/palmerpenguins/articles/intro.html)\n",
        "as\n",
        "[Simple TFX Pipeline Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple).\n",
        "\n",
        "There are four numeric features in this dataset which were already normalized\n",
        "to have range [0,1]. We will build a classification model which predicts the\n",
        "`species` of penguins."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "11J7XiCq6AFP"
      },
      "source": [
        "We need to make our own copy of the dataset. Because TFX ExampleGen reads\n",
        "inputs from a directory, we need to create a directory on GCS and copy the\n",
        "dataset to it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4fxMs6u86acP"
      },
      "outputs": [],
      "source": [
        "!gsutil cp gs://download.tensorflow.org/data/palmer_penguins/penguins_processed.csv {DATA_ROOT}/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ASpoNmxKSQjI"
      },
      "source": [
        "Take a quick look at the CSV file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-eSz28UDSnlG"
      },
      "outputs": [],
      "source": [
        "!gsutil cat {DATA_ROOT}/penguins_processed.csv | head"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nH6gizcpSwWV"
      },
      "source": [
        "## Create a pipeline\n",
        "\n",
        "Our pipeline will be very similar to the pipeline we created in\n",
        "[Simple TFX Pipeline for Vertex Pipelines Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/gcp/vertex_pipelines_simple).\n",
        "The pipeline will consist of three components: CsvExampleGen, Trainer and\n",
        "Pusher. But we will use a special Trainer component which is used to move\n",
        "training workloads to Vertex AI.\n",
        "\n",
        "TFX provides a special `Trainer` to submit training jobs to the Vertex AI\n",
        "Training service. All we have to do is use the `Trainer` in the extension module\n",
        "instead of the standard `Trainer` component, along with some required GCP\n",
        "parameters.\n",
        "\n",
        "In this tutorial, we will run Vertex AI Training jobs only using CPUs first,\n",
        "and then with a GPU."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lOjDv93eS5xV"
      },
      "source": [
        "### Write model code\n",
        "\n",
        "The model itself is almost identical to the model in\n",
        "[Simple TFX Pipeline Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple).\n",
        "\n",
        "We will add a `_get_distribution_strategy()` function which creates a\n",
        "[TensorFlow distribution strategy](https://www.tensorflow.org/guide/distributed_training).\n",
        "`run_fn` uses it to build the model under MirroredStrategy when the `use_gpu`\n",
        "flag is set in the Trainer's `custom_config`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aES7Hv5QTDK3"
      },
      "outputs": [],
      "source": [
        "_trainer_module_file = 'penguin_trainer.py'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gnc67uQNTDfW"
      },
      "outputs": [],
      "source": [
        "%%writefile {_trainer_module_file}\n",
        "\n",
        "# Copied from https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple and\n",
        "# slightly modified run_fn() to add distribution_strategy.\n",
        "\n",
        "from typing import List\n",
        "from absl import logging\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow_metadata.proto.v0 import schema_pb2\n",
        "from tensorflow_transform.tf_metadata import schema_utils\n",
        "\n",
        "from tfx import v1 as tfx\n",
        "from tfx_bsl.public import tfxio\n",
        "\n",
        "_FEATURE_KEYS = [\n",
        "    'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g'\n",
        "]\n",
        "_LABEL_KEY = 'species'\n",
        "\n",
        "_TRAIN_BATCH_SIZE = 20\n",
        "_EVAL_BATCH_SIZE = 10\n",
        "\n",
        "# Since we're not generating or creating a schema, we will instead create\n",
        "# a feature spec.  Since there are a fairly small number of features this is\n",
        "# manageable for this dataset.\n",
        "_FEATURE_SPEC = {\n",
        "    **{\n",
        "        feature: tf.io.FixedLenFeature(shape=[1], dtype=tf.float32)\n",
        "        for feature in _FEATURE_KEYS\n",
        "    }, _LABEL_KEY: tf.io.FixedLenFeature(shape=[1], dtype=tf.int64)\n",
        "}\n",
        "\n",
        "\n",
        "def _input_fn(file_pattern: List[str],\n",
        "              data_accessor: tfx.components.DataAccessor,\n",
        "              schema: schema_pb2.Schema,\n",
        "              batch_size: int) -> tf.data.Dataset:\n",
        "  \"\"\"Generates features and label for training.\n",
        "\n",
        "  Args:\n",
        "    file_pattern: List of paths or patterns of input tfrecord files.\n",
        "    data_accessor: DataAccessor for converting input to RecordBatch.\n",
        "    schema: schema of the input data.\n",
        "    batch_size: representing the number of consecutive elements of returned\n",
        "      dataset to combine in a single batch\n",
        "\n",
        "  Returns:\n",
        "    A dataset that contains (features, indices) tuple where features is a\n",
        "      dictionary of Tensors, and indices is a single Tensor of label indices.\n",
        "  \"\"\"\n",
        "  return data_accessor.tf_dataset_factory(\n",
        "      file_pattern,\n",
        "      tfxio.TensorFlowDatasetOptions(\n",
        "          batch_size=batch_size, label_key=_LABEL_KEY),\n",
        "      schema=schema).repeat()\n",
        "\n",
        "\n",
        "def _make_keras_model() -> tf.keras.Model:\n",
        "  \"\"\"Creates a DNN Keras model for classifying penguin data.\n",
        "\n",
        "  Returns:\n",
        "    A Keras Model.\n",
        "  \"\"\"\n",
        "  # The model below is built with Functional API, please refer to\n",
        "  # https://www.tensorflow.org/guide/keras/overview for all API options.\n",
        "  inputs = [keras.layers.Input(shape=(1,), name=f) for f in _FEATURE_KEYS]\n",
        "  d = keras.layers.concatenate(inputs)\n",
        "  for _ in range(2):\n",
        "    d = keras.layers.Dense(8, activation='relu')(d)\n",
        "  outputs = keras.layers.Dense(3)(d)\n",
        "\n",
        "  model = keras.Model(inputs=inputs, outputs=outputs)\n",
        "  model.compile(\n",
        "      optimizer=keras.optimizers.Adam(1e-2),\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "      metrics=[keras.metrics.SparseCategoricalAccuracy()])\n",
        "\n",
        "  model.summary(print_fn=logging.info)\n",
        "  return model\n",
        "\n",
        "\n",
        "# NEW: Read `use_gpu` from the custom_config of the Trainer.\n",
        "#      If it is set, use MirroredStrategy on a single GPU.\n",
        "def _get_distribution_strategy(fn_args: tfx.components.FnArgs):\n",
        "  if fn_args.custom_config.get('use_gpu', False):\n",
        "    logging.info('Using MirroredStrategy with one GPU.')\n",
        "    return tf.distribute.MirroredStrategy(devices=['device:GPU:0'])\n",
        "  return None\n",
        "\n",
        "\n",
        "# TFX Trainer will call this function.\n",
        "def run_fn(fn_args: tfx.components.FnArgs):\n",
        "  \"\"\"Train the model based on given args.\n",
        "\n",
        "  Args:\n",
        "    fn_args: Holds args used to train the model as name/value pairs.\n",
        "  \"\"\"\n",
        "\n",
        "  # This schema is usually either an output of SchemaGen or a manually-curated\n",
        "  # version provided by pipeline author. A schema can also be derived from TFT\n",
        "  # graph if a Transform component is used. In the case when either is missing,\n",
        "  # `schema_from_feature_spec` could be used to generate schema from very simple\n",
        "  # feature_spec, but the schema returned would be very primitive.\n",
        "  schema = schema_utils.schema_from_feature_spec(_FEATURE_SPEC)\n",
        "\n",
        "  train_dataset = _input_fn(\n",
        "      fn_args.train_files,\n",
        "      fn_args.data_accessor,\n",
        "      schema,\n",
        "      batch_size=_TRAIN_BATCH_SIZE)\n",
        "  eval_dataset = _input_fn(\n",
        "      fn_args.eval_files,\n",
        "      fn_args.data_accessor,\n",
        "      schema,\n",
        "      batch_size=_EVAL_BATCH_SIZE)\n",
        "\n",
        "  # NEW: If we have a distribution strategy, build a model in a strategy scope.\n",
        "  strategy = _get_distribution_strategy(fn_args)\n",
        "  if strategy is None:\n",
        "    model = _make_keras_model()\n",
        "  else:\n",
        "    with strategy.scope():\n",
        "      model = _make_keras_model()\n",
        "\n",
        "  model.fit(\n",
        "      train_dataset,\n",
        "      steps_per_epoch=fn_args.train_steps,\n",
        "      validation_data=eval_dataset,\n",
        "      validation_steps=fn_args.eval_steps)\n",
        "\n",
        "  # The result of the training should be saved in `fn_args.serving_model_dir`\n",
        "  # directory.\n",
        "  model.save(fn_args.serving_model_dir, save_format='tf')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-LsYx8MpYvPv"
      },
      "source": [
        "Copy the module file to GCS so that it can be accessed from the pipeline\n",
        "components.\n",
        "\n",
        "Alternatively, you could build a container image that includes the module file\n",
        "and use the image to run the pipeline and AI Platform Training jobs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rMMs5wuNYAbc"
      },
      "outputs": [],
      "source": [
        "!gsutil cp {_trainer_module_file} {MODULE_ROOT}/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w3OkNz3gTLwM"
      },
      "source": [
        "### Write a pipeline definition\n",
        "\n",
        "We will define a function to create a TFX pipeline. It has the same three\n",
        "Components as in\n",
        "[Simple TFX Pipeline Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple),\n",
        "but we use a `Trainer` component in the GCP extension module.\n",
        "\n",
        "`tfx.extensions.google_cloud_ai_platform.Trainer` behaves like a regular\n",
        "`Trainer`, but it moves the computation for the model training to the cloud.\n",
        "It launches a custom job in the Vertex AI Training service, and the trainer\n",
        "component in the orchestration system just waits until the Vertex AI\n",
        "Training job completes.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M49yYVNBTPd4"
      },
      "outputs": [],
      "source": [
        "def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,\n",
        "                     module_file: str, serving_model_dir: str, project_id: str,\n",
        "                     region: str, use_gpu: bool) -> tfx.dsl.Pipeline:\n",
        "  \"\"\"Implements the penguin pipeline with TFX.\"\"\"\n",
        "  # Brings data into the pipeline or otherwise joins/converts training data.\n",
        "  example_gen = tfx.components.CsvExampleGen(input_base=data_root)\n",
        "\n",
        "  # NEW: Configuration for Vertex AI Training.\n",
        "  # This dictionary will be passed as `CustomJobSpec`.\n",
        "  vertex_job_spec = {\n",
        "      'project': project_id,\n",
        "      'worker_pool_specs': [{\n",
        "          'machine_spec': {\n",
        "              'machine_type': 'n1-standard-4',\n",
        "          },\n",
        "          'replica_count': 1,\n",
        "          'container_spec': {\n",
        "              'image_uri': 'gcr.io/tfx-oss-public/tfx:{}'.format(tfx.__version__),\n",
        "          },\n",
        "      }],\n",
        "  }\n",
        "  if use_gpu:\n",
        "    # See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#acceleratortype\n",
        "    # for available machine types.\n",
        "    vertex_job_spec['worker_pool_specs'][0]['machine_spec'].update({\n",
        "        'accelerator_type': 'NVIDIA_TESLA_K80',\n",
        "        'accelerator_count': 1\n",
        "    })\n",
        "\n",
        "  # Trains a model using Vertex AI Training.\n",
        "  # NEW: We need to specify a Trainer for GCP with related configs.\n",
        "  trainer = tfx.extensions.google_cloud_ai_platform.Trainer(\n",
        "      module_file=module_file,\n",
        "      examples=example_gen.outputs['examples'],\n",
        "      train_args=tfx.proto.TrainArgs(num_steps=100),\n",
        "      eval_args=tfx.proto.EvalArgs(num_steps=5),\n",
        "      custom_config={\n",
        "          tfx.extensions.google_cloud_ai_platform.ENABLE_UCAIP_KEY:\n",
        "              True,\n",
        "          tfx.extensions.google_cloud_ai_platform.UCAIP_REGION_KEY:\n",
        "              region,\n",
        "          tfx.extensions.google_cloud_ai_platform.TRAINING_ARGS_KEY:\n",
        "              vertex_job_spec,\n",
        "          'use_gpu':\n",
        "              use_gpu,\n",
        "      })\n",
        "\n",
        "  # Pushes the model to a filesystem destination.\n",
        "  pusher = tfx.components.Pusher(\n",
        "      model=trainer.outputs['model'],\n",
        "      push_destination=tfx.proto.PushDestination(\n",
        "          filesystem=tfx.proto.PushDestination.Filesystem(\n",
        "              base_directory=serving_model_dir)))\n",
        "\n",
        "  components = [\n",
        "      example_gen,\n",
        "      trainer,\n",
        "      pusher,\n",
        "  ]\n",
        "\n",
        "  return tfx.dsl.Pipeline(\n",
        "      pipeline_name=pipeline_name,\n",
        "      pipeline_root=pipeline_root,\n",
        "      components=components)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mJbq07THU2GV"
      },
      "source": [
        "## Run the pipeline on Vertex Pipelines\n",
        "\n",
        "We will use Vertex Pipelines to run the pipeline as we did in\n",
        "[Simple TFX Pipeline for Vertex Pipelines Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/gcp/vertex_pipelines_simple)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fAtfOZTYWJu-"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "PIPELINE_DEFINITION_FILE = PIPELINE_NAME + '_pipeline.json'\n",
        "\n",
        "runner = tfx.orchestration.experimental.KubeflowV2DagRunner(\n",
        "    config=tfx.orchestration.experimental.KubeflowV2DagRunnerConfig(),\n",
        "    output_filename=PIPELINE_DEFINITION_FILE)\n",
        "_ = runner.run(\n",
        "    _create_pipeline(\n",
        "        pipeline_name=PIPELINE_NAME,\n",
        "        pipeline_root=PIPELINE_ROOT,\n",
        "        data_root=DATA_ROOT,\n",
        "        module_file=os.path.join(MODULE_ROOT, _trainer_module_file),\n",
        "        serving_model_dir=SERVING_MODEL_DIR,\n",
        "        project_id=GOOGLE_CLOUD_PROJECT,\n",
        "        region=GOOGLE_CLOUD_REGION,\n",
        "        # We will use CPUs only for now.\n",
        "        use_gpu=False))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fWyITYSDd8w4"
      },
      "source": [
        "The generated definition file can be submitted using the kfp client."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tI71jlEvWMV7"
      },
      "outputs": [],
      "source": [
        "# docs_infra: no_execute\n",
        "from kfp.v2.google import client\n",
        "\n",
        "pipelines_client = client.AIPlatformClient(\n",
        "    project_id=GOOGLE_CLOUD_PROJECT,\n",
        "    region=GOOGLE_CLOUD_REGION,\n",
        ")\n",
        "\n",
        "_ = pipelines_client.create_run_from_job_spec(PIPELINE_DEFINITION_FILE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L3k9f5IVQXcQ"
      },
      "source": [
        "Now you can visit the link in the output above or visit\n",
        "'Vertex AI > Pipelines' in\n",
        "[Google Cloud Console](https://console.cloud.google.com/) to see the\n",
        "progress."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DgVvdYPzzW6k"
      },
      "source": [
        "### Run the pipeline using a GPU\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ht0Zpgx3L82g"
      },
      "source": [
        "Vertex AI supports training on various machine types, including machines with\n",
        "GPUs. See\n",
        "[Machine spec reference](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#acceleratortype)\n",
        "for available options.\n",
        "\n",
        "We already defined our pipeline to support GPU training. All we need to do is\n",
        "set the `use_gpu` flag to True. Then a pipeline will be created with a machine\n",
        "spec including one NVIDIA_TESLA_K80 and our model training code will use\n",
        "`tf.distribute.MirroredStrategy`.\n",
        "\n",
        "Note that the `use_gpu` flag is not a part of the Vertex or TFX API. It is just\n",
        "used to control the training code in this tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1TwX6bcsLo_g"
      },
      "outputs": [],
      "source": [
        "# docs_infra: no_execute\n",
        "runner.run(\n",
        "    _create_pipeline(\n",
        "        pipeline_name=PIPELINE_NAME,\n",
        "        pipeline_root=PIPELINE_ROOT,\n",
        "        data_root=DATA_ROOT,\n",
        "        module_file=os.path.join(MODULE_ROOT, _trainer_module_file),\n",
        "        serving_model_dir=SERVING_MODEL_DIR,\n",
        "        project_id=GOOGLE_CLOUD_PROJECT,\n",
        "        region=GOOGLE_CLOUD_REGION,\n",
        "        # Updated: Use GPUs. We will use an NVIDIA_TESLA_K80 and\n",
        "        # the model code will use tf.distribute.MirroredStrategy.\n",
        "        use_gpu=True))\n",
        "\n",
        "_ = pipelines_client.create_run_from_job_spec(PIPELINE_DEFINITION_FILE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xc9XsjlyKoZe"
      },
      "source": [
        "Now you can visit the link in the output above or visit\n",
        "'Vertex AI > Pipelines' in\n",
        "[Google Cloud Console](https://console.cloud.google.com/) to see the\n",
        "progress."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "pknVo1kM2wI2",
        "8F2SRwRLSYGa"
      ],
      "name": "vertex_pipelines_vertex_training.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
